context("indexing")

test_that("[ works", {
  # Positional indices and slices, checked against the equivalent subsetting
  # of the R array returned by `as_array()`.
  x <- torch_randn(c(10, 10, 10))
  expect_equal(as_array(x[1, 1, 1]), as_array(x)[1, 1, 1])
  expect_equal(as_array(x[1, , ]), as_array(x)[1, , ])
  expect_equal(as_array(x[1:5, , ]), as_array(x)[1:5, , ])
  # `start:end:step` is compared with `seq(start, end, by = step)`.
  expect_equal(as_array(x[1:10:2, , ]), as_array(x)[seq(1, 10, by = 2), , ])

  # Negative positions count from the end; `N` marks the end of the dimension.
  x <- torch_tensor(0:9)
  expect_equal(as_array(x[-1]$to(dtype = torch_int())), 9)
  expect_equal(as_array(x[-2:10]$to(dtype = torch_int())), c(8, 9))
  expect_equal(as_array(x[2:N]$to(dtype = torch_int())), c(1:9))

  # `..` stands in for all dimensions that are not indexed explicitly.
  x <- torch_randn(c(10, 10, 10, 10))
  expect_equal(as_array(x[1, ..]), as_array(x)[1, , , ])
  expect_equal(as_array(x[1, 1, ..]), as_array(x)[1, 1, , ])
  expect_equal(as_array(x[.., 1]), as_array(x)[, , , 1])
  expect_equal(as_array(x[.., 1, 1]), as_array(x)[, , 1, 1])

  # `!!!` splices the elements of a vector in as separate indices.
  x <- torch_randn(c(10, 10, 10, 10))
  i <- c(1, 2, 3, 4)
  expect_equal(as_array(x[!!!i]), as_array(x)[1, 2, 3, 4])
  i <- c(1, 2)
  expect_equal(as_array(x[!!!i, 3, 4]), as_array(x)[1, 2, 3, 4])

  # Numeric vectors (literal or stored in a variable) select by position.
  x <- torch_tensor(1:10)
  y <- 1:10
  expect_equal_to_r(x[c(1, 3, 2, 5)]$to(dtype = torch_int()), y[c(1, 3, 2, 5)])

  index <- 1:3
  expect_equal_to_r(x[index]$to(dtype = torch_int()), y[index])

  # Vector indices on several dimensions at once match R matrix subsetting.
  x <- torch_randn(10, 10)
  expect_equal_to_r(x[c(2, 3, 1), c(3, 2, 1)], as_array(x)[c(2, 3, 1), c(3, 2, 1)])

  # A trailing `..` on a one-dimensional tensor changes nothing.
  x <- torch_randn(10)
  expect_equal_to_tensor(x[1:5, ..], x[1:5])

  # `NULL` inserts a dimension of size 1 at its position.
  x <- torch_randn(10)
  expect_tensor_shape(x[, NULL], c(10, 1))
  expect_tensor_shape(x[NULL, , NULL], c(1, 10, 1))
  expect_tensor_shape(x[NULL, , NULL, NULL], c(1, 10, 1, 1))

  # `newaxis` produces the same shapes as `NULL`.
  x <- torch_randn(10)
  expect_tensor_shape(x[, newaxis], c(10, 1))
  expect_tensor_shape(x[newaxis, , newaxis], c(1, 10, 1))
  expect_tensor_shape(x[newaxis, , newaxis, newaxis], c(1, 10, 1, 1))

  # `drop = FALSE` keeps the indexed dimension with size 1.
  x <- torch_randn(10, 10)
  expect_tensor_shape(x[1, , drop = FALSE], c(1, 10))
  expect_tensor_shape(x[.., 1, drop = FALSE], c(10, 1))
  expect_tensor_shape(x[.., -1, drop = FALSE], c(10, 1))
})

test_that("indexing error expectations", {
  # Five indices on a four-dimensional tensor must raise an error.
  four_d <- torch_randn(c(10, 10, 10, 10))
  expect_error(four_d[1, 1, 1, 1, 1])

  # Index 0 must raise an error, both alone and inside a vector.
  single <- torch_tensor(10)
  expect_error(single[0])
  expect_error(single[c(0, 1)])
})

test_that("indexing with boolean tensor", {
  # One-dimensional tensor: the mask keeps the negative entries.
  vec <- torch_tensor(c(-1, -2, 0, 1, 2))
  expect_equal_to_r(vec[vec < 0], c(-1, -2))

  # Two-row tensor: the selected values are compared with a flat vector.
  mat <- torch_tensor(matrix(
    c(
      -1, -2, 0, 1, 2,
      2, 1, 0, -1, -2
    ),
    nrow = 2,
    byrow = TRUE
  ))
  expect_equal_to_r(mat[mat < 0], c(-1, -2, -1, -2))

  # Adding a further index after the full-shape mask is expected to fail.
  expect_error(mat[mat < 0, 1])
})

test_that("slice with negative indexes", {
  vec <- torch_tensor(c(1, 2, 3))

  # Slices whose end, or both ends, are given as negative positions.
  expect_equal_to_r(vec[2:-1], c(2, 3))
  expect_equal_to_r(vec[-2:-1], c(2, 3))
  expect_equal_to_r(vec[-3:-2], c(1, 2))

  # A vector of negative positions returns elements in the requested order.
  expect_equal_to_r(vec[c(-1, -2)], c(3, 2))
})

test_that("subset assignment", {
  # Assign a tensor to a single element.
  x <- torch_randn(2, 2)
  x[1, 1] <- torch_tensor(0)
  expect_equal_to_r(x[1, 1], 0)

  # Assign an R double to a single element.
  x[1, 2] <- 0
  expect_equal_to_r(x[1, 2], 0)

  # Assign an R integer to a single element.
  x[1, 2] <- 1L
  expect_equal_to_r(x[1, 2], 1)

  # Assign an R logical into a tensor built from logicals.
  x <- torch_tensor(c(TRUE, FALSE))
  x[2] <- TRUE
  expect_equal_to_r(x[2], TRUE)

  # Assign through a boolean mask: every non-positive entry becomes 1.
  x <- torch_tensor(rbind(
    c(-1, -2, 0, 1, 2),
    c(2, 1, 0, -1, -2)
  ))

  x[x <= 0] <- 1
  expect_true(as_array(torch_all(x > 0)))

  # Assign an R vector to a slice.
  x <- torch_tensor(c(1, 2, 3, 4, 5))
  x[1:2] <- c(0, 0)
  expect_equal_to_r(x[1:2], c(0, 0))
})

test_that("indexing with R boolean vectors", {
  pair <- torch_tensor(c(1, 2))

  # A single TRUE / FALSE is compared against a one-row / zero-row matrix.
  expect_equal_to_r(pair[TRUE], matrix(c(1, 2), nrow = 1))
  expect_equal_to_r(pair[FALSE], matrix(1, nrow = 0, ncol = 2))

  # A logical vector as long as the tensor keeps the TRUE positions.
  expect_equal_to_r(pair[c(TRUE, FALSE)], 1)
})

test_that("indexing with long tensors", {
  mat <- torch_randn(4, 4)

  # One-element long tensor holding 1: same value as `[1, 1]`, shape (1, 1).
  pos_one <- torch_tensor(1, dtype = torch_long())
  expect_equal(mat[pos_one, pos_one]$item(), mat[1, 1]$item())
  expect_tensor_shape(mat[pos_one, pos_one], c(1, 1))

  # Scalar long tensor holding 1: equal to plain `[1, 1]` indexing.
  pos_one_scalar <- torch_scalar_tensor(1, dtype = torch_long())
  expect_equal_to_tensor(mat[pos_one_scalar, pos_one_scalar], mat[1, 1])

  # The same two checks with -1.
  neg_one <- torch_tensor(-1, dtype = torch_long())
  expect_equal(mat[neg_one, neg_one]$item(), mat[-1, -1]$item())
  expect_tensor_shape(mat[neg_one, neg_one], c(1, 1))

  neg_one_scalar <- torch_scalar_tensor(-1, dtype = torch_long())
  expect_equal_to_tensor(mat[neg_one_scalar, neg_one_scalar], mat[-1, -1])

  # A long tensor with several positions behaves like the matching R vector.
  both_ends <- torch_tensor(c(-1, 1), dtype = torch_long())
  expect_equal_to_tensor(mat[both_ends, both_ends], mat[c(-1, 1), c(-1, 1)])

  # A zero inside the index tensor must raise the "starts at 1" error.
  with_zero <- torch_tensor(c(-1, 0, 1), dtype = torch_long())
  expect_error(mat[with_zero, ], regexp = "Indexing starts at 1")
})

test_that("can use the slc construct", {
  tens <- torch_randn(10, 10)
  arr <- as_array(tens)
  # R-side counterpart of `slc(start = 1, end = 5, step = 2)`.
  stepped <- seq(1, 5, by = 2)

  # Slice on the first dimension only.
  expect_equal_to_r(
    tens[slc(start = 1, end = 5, step = 2), ],
    arr[stepped, ]
  )

  # Slice combined with a single position.
  expect_equal_to_r(
    tens[slc(start = 1, end = 5, step = 2), 1],
    arr[stepped, 1]
  )

  # Slice on both dimensions.
  expect_equal_to_r(
    tens[slc(start = 1, end = 5, step = 2), slc(start = 1, end = 5, step = 2)],
    arr[stepped, stepped]
  )

  # An infinite end gives the same result as `2:N`.
  expect_equal_to_tensor(tens[slc(2, Inf), ], tens[2:N, ])
})

test_that("print slice", {
  # Switch this test to testthat's 3rd edition before taking the snapshot.
  testthat::local_edition(3)
  # Snapshot of printing the slice object created by `slc(1, 3, 5)`.
  expect_snapshot(print(slc(1, 3, 5)))
})

test_that("mix vector indexing with slices and others", {
  cube <- torch_randn(3, 3, 3)

  # In every case the vector `c(1, 2)` must select the same elements as the
  # slice `1:2` in the same position.

  # Vector, slice, vector.
  expect_equal_to_tensor(cube[c(1, 2), 1:2, c(1, 2)], cube[1:2, 1:2, 1:2])

  # Vector, new axis, slice, vector.
  expect_equal_to_tensor(
    cube[c(1, 2), newaxis, 1:2, c(1, 2)],
    cube[1:2, newaxis, 1:2, 1:2]
  )

  # Leading new axis followed by the previous pattern.
  expect_equal_to_tensor(
    cube[newaxis, c(1, 2), newaxis, 1:2, c(1, 2)],
    cube[newaxis, 1:2, newaxis, 1:2, 1:2]
  )

  # Two vectors and an empty index.
  expect_equal_to_tensor(cube[c(1, 2), c(1, 2), ], cube[1:2, 1:2, ])

  # Vectors separated by an empty index.
  expect_equal_to_tensor(cube[c(1, 2), , c(1, 2)], cube[1:2, , 1:2])

  # Vectors on all three dimensions.
  expect_equal_to_tensor(cube[c(1, 2), c(1, 2), c(1, 2)], cube[1:2, 1:2, 1:2])

  # Vectors on all three dimensions with a new axis in between.
  expect_equal_to_tensor(
    cube[c(1, 2), c(1, 2), newaxis, c(1, 2)],
    cube[1:2, 1:2, newaxis, 1:2]
  )
})

test_that("boolean vector indexing works as expected", {
  tens <- torch_randn(4, 4, 4)
  keep <- c(TRUE, FALSE, TRUE, FALSE)

  # The same logical vector on every dimension must match R array subsetting.
  expect_equal_to_r(tens[keep, keep, keep], as_array(tens)[keep, keep, keep])
})

test_that("regression test for #691", {
  tens <- torch_randn(c(6, 4))
  idx <- c(1, 2, 3)

  # The result is deliberately discarded; only the effect on `idx` matters.
  tens[idx]

  # The index vector must still hold its original values after indexing.
  expect_equal(idx, c(1, 2, 3))
})

test_that("regression test for #695", {
  tens <- torch_randn(c(3, 4, 2))
  picks <- torch_tensor(c(1, 3), dtype = torch_long())

  # Long-tensor index placed between `..` and an empty index.
  expect_equal_to_r(tens[.., picks, ], as.array(tens)[, c(1, 3), ])

  tens <- torch_randn(c(3, 4, 3))

  # Two long-tensor indices after `..`.
  expect_equal_to_r(
    tens[.., picks, picks],
    as.array(tens)[, c(1, 3), c(1, 3)]
  )

  # Long-tensor indices on both sides of `..`.
  expect_equal_to_r(
    tens[picks, .., picks],
    as.array(tens)[c(1, 3), , c(1, 3)]
  )
})

test_that("NULL tensor", {
  # A tensor built from NULL is expected to have bool dtype and shape 0.
  empty <- torch_tensor(NULL)
  expect_true(empty$dtype == torch_bool())
  expect_equal(empty$shape, 0)

  # Subsetting an empty tensor must raise an error instead of crashing.
  expect_error(empty[1], regexp = "out of bounds")
  expect_error(torch_tensor(integer(0))[1], regexp = "out of bounds")
})
